{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Adding a New Model\n",
    "\n",
    "This tutorial shows how you can add a new model to the `neuralhydrology` modelzoo.\n",
    "As an example, we'll implement a GRU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from neuralhydrology.modelzoo import get_model\n",
    "from neuralhydrology.modelzoo.head import get_head\n",
    "from neuralhydrology.modelzoo.basemodel import BaseModel\n",
    "from neuralhydrology.modelzoo.template import TemplateModel\n",
    "from neuralhydrology.utils.config import Config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Template\n",
    "\n",
    "Every model has its own file in `neuralhydrology.modelzoo` and follows a common template that you can find [here](https://github.com/neuralhydrology/neuralhydrology/blob/master/neuralhydrology/modelzoo/template.py).\n",
    "\n",
    "The most important points about these templates are:\n",
    "\n",
    "- All models inherit from the `BaseModel` that's implemented in `neuralhydrology.modelzoo.basemodel`.\n",
    "- All models' constructors take just one argument, an instance of the configuration class (`Config`). The constructor initializes the model and its components.\n",
    "- Finally, each model implements its own logic in the `forward` method. This is where the actual magic happens: The forward method takes the input data during training and evaluation and uses it to generate a prediction.\n",
    "\n",
    "In the following steps, we'll go over the constructor and the forward method in more detail."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Adding a GRU Model\n",
    "\n",
    "So, let's follow that template and add a [GRU](https://en.wikipedia.org/wiki/Gated_recurrent_unit) model.\n",
    "Fortunately, there already exists a [GRU implementation](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html) in the PyTorch libary, so we can wrap our code around that existing model.\n",
    "This way, we can be pretty sure to get a correct and reasonably fast implementation without much effort.\n",
    "\n",
    "For the sake of brevity, we'll omit the docstrings in this example. If you actually implement a model for production use, you should always write the documentation right within your code.\n",
    "\n",
    "### GRU Components\n",
    "\n",
    "Every model's constructor receives a single argument: an instance of the run configuration.\n",
    "Based on this config, we'll construct the GRU.\n",
    "\n",
    "Like most our models, the GRU will consist of three components: \n",
    "\n",
    "- An optional input layer that acts as an embedding network for static or dynamic features. If used, the features will be passed through a fully-connected network before we pass them to the actual GRU. If no embedding is specified, this layer will do nothing.\n",
    "- The \"body\" that represents the actual GRU cell.\n",
    "- The \"head\" that acts as a final output layer.\n",
    "\n",
    "To maintain a modular architecture, the input and head layers should not be implemented inside the model. Instead, we should use the `InputLayer` in `neuralhydrology.modelzoo.inputlayer` and the `get_head` function in `neuralhydrology.modelzoo.head` which will automatically construct layers that fit to the run configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRU(BaseModel):\n",
    "\n",
    "    # specify submodules of the model that can later be used for finetuning. Names must match class attributes\n",
    "    module_parts = ['embedding_net', 'gru', 'head']\n",
    "\n",
    "    def __init__(self, cfg: Config):\n",
    "\n",
    "        super(GRU, self).__init__(cfg=cfg)\n",
    "\n",
    "        # retrieve the input layer\n",
    "        self.embedding_net = InputLayer(cfg)\n",
    "\n",
    "        # create the actual GRU\n",
    "        self.gru = nn.GRU(input_size=self.embedding_net.output_size, hidden_size=cfg.hidden_size)\n",
    "\n",
    "        # add dropout between GRU and head\n",
    "        self.dropout = nn.Dropout(p=cfg.output_dropout)\n",
    "\n",
    "        # retrieve the model head\n",
    "        self.head = get_head(cfg=cfg, n_in=cfg.hidden_size, n_out=self.output_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Implementing the Forward Pass\n",
    "\n",
    "Now we have a class called GRU, but we haven't yet told the model how to process incoming data.\n",
    "That's what we do in the `forward` method.\n",
    "\n",
    "By convention, our models' `forward` method accepts and returns dictionaries that map names (strings) to tensors.\n",
    "The input dictionary (`data`) usually contains at least a key 'x_d' and possibly 'x_s' and 'x_one_hot'.\n",
    "We say \"usually\", because models that support simultaneous prediction at multiple timescales (e.g., MTS-LSTM) will\n",
    "get one 'x_d' and 'x_s' for each timescale, suffixed with the frequency identifier (e.g., 'x_d_1H' for hourly dynamic inputs).\n",
    "\n",
    "But for this example, let's assume a single-timescale model. Let's dive deeper into what each of the input values contain:\n",
    "\n",
    "| Key         | Shape                                     | Description                                                                                                                                                                                                                                                 |\n",
    "|:------------|:------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
    "| 'x_d'       | `[batch size, sequence length, features]` | the dynamic input data                                                                                                                                                                                                                                      |\n",
    "| 'x_s'       | `[batch size, features]`                  | static input features. These are the concatenation of what is defined in the run configuration under 'static_attributes' and 'evolving_attributes'. If not a single static or evolving attribute is defined in the config, 'x_s' will not be present.       |\n",
    "| 'x_one_hot' | `[batch size, number of basins]`          | one-hot encoding of the basins. If 'use_basin_id_encoding' is set to False in the run configuration, 'x_one_hot' will not be present.                                                                                                                       |\n",
    "\n",
    "Now, given these input data we're supposed to generate a prediction that we return as 'y_hat' (multi-timescale models would return 'y_hat_1H', ...).\n",
    "The returned 'y_hat' should contain a prediction for the _full_ input sequence (not just the last element), even if you're using sequence-to-one prediction.\n",
    "The loss will sort out which of these predictions actually need to be used in the current training configuration.\n",
    "All models should at least return 'y_hat', but we can return any other potentially useful information.\n",
    "In our case, we can additionally return the final hidden state that we'll receive from the PyTorch GRU implementation.\n",
    "The naming convention for hidden states is to call them 'h_n'.\n",
    "\n",
    "So, here we go:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "\n",
    "    # possibly pass dynamic and static inputs through embedding layers, then concatenate them\n",
    "    x_d = self.embedding_net(data, concatenate_output=True)    \n",
    "\n",
    "    # run the actual GRU\n",
    "    gru_output, h_n = self.gru(input=x_d)\n",
    "\n",
    "    # reshape to [batch_size, 1, n_hiddens]\n",
    "    h_n = h_n.transpose(0, 1)\n",
    "\n",
    "    pred = {'h_n': h_n}\n",
    "    \n",
    "    # add the final output as it's returned by the head to the prediction dict\n",
    "    # (this will contain the 'y_hat')\n",
    "    pred.update(self.head(self.dropout(gru_output.transpose(0, 1))))\n",
    "\n",
    "    return pred\n",
    "\n",
    "# usually, we'd implement the forward pass right where we define the class.\n",
    "# For this tutorial, we've broken it down into the constructor and the forward pass,\n",
    "# so now we'll just add the forward method to the GRU class:\n",
    "GRU.forward = forward"
   ]
  },
  {
   "source": [
    "As you see, much of the heavy lifting is being done by existing methods, so we just have to wire everything up.\n",
    "The input layer merges the static inputs (`data['x_s']` and/or `data['x_one_hot']`) to each step of the dynamic inputs (`data['x_d']`) and returns a single tensor that we can pass to the GRU cell.\n",
    "\n",
    "### Using the Model\n",
    "\n",
    "That's it! We now have a working GRU model that we can use to train and evaluate models.\n",
    "The only thing left is registering the model in the `get_model` method of `neuralhydrology.modelzoo` to make sure we can specify the model in a run configuration.\n",
    "\n",
    "Since GRU already exists in the modelzoo, it's already there:\n"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "def get_model(cfg: Config) -> nn.Module:\n    \"\"\"Get model object, depending on the run configuration.\n    \n    Parameters\n    ----------\n    cfg : Config\n        The run configuration.\n\n    Returns\n    -------\n    nn.Module\n        A new model instance of the type specified in the config.\n    \"\"\"\n    if cfg.model in SINGLE_FREQ_MODELS and len(cfg.use_frequencies) > 1:\n        raise ValueError(f\"Model {cfg.model} does not support multiple frequencies.\")\n\n    if cfg.model == \"cudalstm\":\n        model = CudaLSTM(cfg=cfg)\n    elif cfg.model == \"ealstm\":\n        model = EALSTM(cfg=cfg)\n    elif cfg.model == \"customlstm\":\n        model = CustomLSTM(cfg=cfg)\n    elif cfg.model == \"lstm\":\n        warnings.warn(\n            \"The `LSTM` class has been renamed to `CustomLSTM`. Support for `LSTM` will we dropped in the future.\",\n            FutureWarning)\n        model = CustomLSTM(cfg=cfg)\n    elif cfg.model == \"gru\":\n        model = GRU(cfg=cfg)\n    elif cfg.model == \"embcudalstm\":\n        model = EmbCudaLSTM(cfg=cfg)\n    elif cfg.model == \"mtslstm\":\n        model = MTSLSTM(cfg=cfg)\n    elif cfg.model == \"odelstm\":\n        model = ODELSTM(cfg=cfg)\n    else:\n        raise NotImplementedError(f\"{cfg.model} not implemented or not linked in `get_model()`\")\n\n    return model\n\n"
     ]
    }
   ],
   "source": [
    "print(inspect.getsource(get_model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since GRU is registered as a model, you can now specify `model: gru` in the run configuration and use the model, just like any other.\n",
    "For an example of training and evaluating a model, take a look at the [introduction tutorial](https://neuralhydrology.readthedocs.io/en/latest/tutorials/introduction.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
